#!/bin/bash
# Inference sweep: iterates over every (subset size, model, dataset, embedding,
# k, method) combination configured below and calls scripts/run_inference.sh.

# Expose only GPU 0 to CUDA, with devices numbered in PCI bus order.
export CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0

# From the `exec` on, all stdout/stderr of this script goes to the log file,
# which is truncated at the start of every invocation.
log_dir="./results/log"
log_file="${log_dir}/inference_geo880_0.out"
mkdir -p "$log_dir"
exec >"$log_file" 2>&1

# ---- Sweep configuration: scalars ------------------------------------------
apply_chat_template=false   # forwarded verbatim as --apply_chat_template
exp_num=10                  # repetitions for methods whose name lacks "knn"
metric="cosine_similarity"  # forwarded as --metric

# ---- Sweep configuration: axes (one run per combination) -------------------
declare -a datasets=("geo880-standard")
declare -a subset_sizes=(100)
declare -a modelabbrs=("llama-3.1-8b")
declare -a ks=(4 8)
declare -a methods=("knn")
# NOTE(review): dp_choices is not referenced anywhere in this script;
# confirm nothing else reads it before removing.
declare -a dp_choices=("knn")
declare -a embs=("all-roberta-large-v1")

#######################################
# Best-effort GPU reset used after a timed-out run, followed by a short pause
# before the next attempt.
# Globals:   CUDA_VISIBLE_DEVICES (read) - default device(s) to reset
# Arguments: $1 - optional nvidia-smi device id(s) (index, comma-separated list
#                 or UUID); defaults to CUDA_VISIBLE_DEVICES. If both are
#                 empty, nvidia-smi is called without -i (resets all GPUs).
# Outputs:   a warning on stderr if the reset command fails
# Returns:   0 - a failed reset is reported but not fatal
#######################################
clean_gpu_cache() {
    local device=${1:-${CUDA_VISIBLE_DEVICES:-}}
    local -a reset_cmd=(nvidia-smi --gpu-reset)
    # Without -i, nvidia-smi resets *every* GPU on the machine. Restrict the
    # reset to the device(s) this script is pinned to, so other jobs are not
    # disturbed. With CUDA_DEVICE_ORDER=PCI_BUS_ID, CUDA indices line up with
    # nvidia-smi's.
    if [[ -n "$device" ]]; then
        reset_cmd+=(-i "$device")
    fi
    if ! "${reset_cmd[@]}"; then
        # --gpu-reset needs root and refuses while processes still use the GPU.
        echo "Warning: '${reset_cmd[*]}' failed; continuing without GPU reset" >&2
    fi
    sleep 2
}

# Count how many runs the sweep below iterates over (used for progress output).
# Only `methods` changes the per-combination count: a method whose name
# contains "knn" contributes 1 run, any other method contributes exp_num runs.
# Every other axis simply multiplies that count by its length.
runs_per_combo=0
for method in "${methods[@]}"; do
    case "$method" in
        *knn*) exp_num_method=1 ;;
        *) exp_num_method=$exp_num ;;
    esac
    runs_per_combo=$((runs_per_combo + exp_num_method))
done

total_num=$((${#subset_sizes[@]} * ${#modelabbrs[@]} * ${#datasets[@]} \
    * ${#embs[@]} * ${#ks[@]} * runs_per_combo))

echo "Total number of runs: $total_num"

# Resume point: runs whose 1-based number is below target_num are skipped;
# -1 runs everything.
target_num=-1

# Wall-clock limit (seconds) for one inference run, and how many times a run
# that hits it is retried before giving up on it. MAX_RETRIES=0 retries
# forever (the previous behaviour). Both can be overridden from the environment.
inference_timeout=${INFERENCE_TIMEOUT:-600}
max_retries=${MAX_RETRIES:-3}

#######################################
# Run scripts/run_inference.sh under `timeout`, resetting the GPU and retrying
# whenever the time limit is hit (timeout's exit status 124).
# Globals:   inference_timeout, max_retries (read)
# Arguments: forwarded unchanged to scripts/run_inference.sh
# Outputs:   retry / failure messages on stdout
# Returns:   exit status of the last attempt (124 if it still timed out)
#######################################
run_inference_with_retry() {
    local status
    local attempts=0
    while true; do
        timeout "$inference_timeout" bash scripts/run_inference.sh "$@"
        status=$?
        if (( status != 124 )); then
            # Anything other than a timeout is final. Report failures instead
            # of silently moving on to the next run.
            if (( status != 0 )); then
                echo "Inference run failed with exit status $status, moving on"
            fi
            return "$status"
        fi
        attempts=$((attempts + 1))
        if (( max_retries > 0 && attempts > max_retries )); then
            echo "Command execution timed out $attempts times, giving up on this run"
            return "$status"
        fi
        echo "Command execution timed out, cleaning GPU cache and retrying..."
        clean_gpu_cache
    done
}

current_num=0
for subset_size in "${subset_sizes[@]}"; do
    for modelabbr in "${modelabbrs[@]}"; do
        # change the modelname here (local checkpoint directory per model)
        modelname="/home/amax/exp/huggingface/transformers/${modelabbr}"
        for data in "${datasets[@]}"; do
            for emb in "${embs[@]}"; do
                for k in "${ks[@]}"; do
                    for method in "${methods[@]}"; do
                        # Must mirror the counting loop that computes total_num.
                        if [[ "$method" == *"knn"* ]]; then
                            exp_num_method=1
                        else
                            exp_num_method=$exp_num
                        fi
                        for ((i = 0; i < exp_num_method; i++)); do
                            # Plain assignment: ((x++)) returns 1 when x was 0.
                            current_num=$((current_num + 1))
                            if (( current_num < target_num )); then
                                continue
                            fi
                            # Methods outside this list are counted but not run.
                            if [[ "$method" == *"knn"* || "$method" == "random" || "$method" == *"diversity"* || "$method" == "compute_relation" ]]; then
                                echo "Current run status: $current_num / $total_num"
                                run_inference_with_retry "$modelname" "$modelabbr" "$data" greedy \
                                    --subset_size "$subset_size" \
                                    --k "$k" \
                                    --exp_num "$i" \
                                    --method "$method" \
                                    --emb "$emb" \
                                    --metric "$metric" \
                                    --apply_chat_template "$apply_chat_template"
                            fi
                        done
                    done
                done
            done
        done
    done
done